package gov.cms.grouper.snf;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import gov.cms.grouper.snf.lego.SnfComparator;
import gov.cms.grouper.snf.lego.SnfUtils;
import gov.cms.grouper.snf.model.table.CmgForPtOtRow;
import gov.cms.grouper.snf.model.table.NtaComorbidityRow;
import gov.cms.grouper.snf.model.table.SnfVersionRow;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;


/**
 * Unit tests for the table lookup helpers in {@code SnfTables}: {@code get}, {@code initMap} and
 * {@code selectAll}.
 */
public class SnfContextTest {

  /**
   * Version number constant. Not referenced by the tests in this class; kept public for
   * compatibility with any external users.
   */
  public static final int version = 100;

  /**
   * Verifies that {@code SnfTables.get} returns the row accepted by the selector, and
   * {@code null} when the selector rejects every row or the key is absent from the map.
   */
  @Test
  public void testGet() {
    List<CmgForPtOtRow> items = this.initModel();
    // Index 4 is the fixture row with clinical category "test1" and CMG "cmg4" (see initModel).
    CmgForPtOtRow expected = items.get(4);

    Map<String, Set<CmgForPtOtRow>> map = items.stream()
        .collect(Collectors.groupingBy(CmgForPtOtRow::getClinicalCategory, Collectors.toSet()));

    BiFunction<CmgForPtOtRow, Integer, Boolean> selector =
        (item, checkVer) -> item.getClinicalCategory().equals("test1")
            && item.getCmg().equals("cmg4")
            && SnfComparator.betweenInclusive(item.getLowVersion(), checkVer, item.getHighVersion());

    CmgForPtOtRow actual = SnfTables.get(map, "test1", selector, 200);
    assertEquals(expected, actual);

    // A selector that rejects everything yields null, for both an existing and a missing key.
    selector = (item, checkVer) -> false;

    expected = null;
    actual = SnfTables.get(map, "test1", selector, 200);
    assertEquals(expected, actual);

    actual = SnfTables.get(map, "test-1", selector, 200);
    assertEquals(expected, actual);
  }

  /**
   * Verifies that {@code SnfTables.initMap} groups the loaded rows by the supplied key function
   * (here the clinical category). The role of the versions table is not exercised beyond
   * supplying a single entry for version 100.
   */
  @Test
  public void testInitMap() {
    Supplier<Set<CmgForPtOtRow>> loader = () -> SnfUtils.toSet(this.initModel());
    Map<Integer, SnfVersionRow> versionsTable = new HashMap<>();
    versionsTable.put(100, new SnfVersionRow(100, null, null, null));

    Map<String, Set<CmgForPtOtRow>> data =
        SnfTables.initMap(loader, (row) -> row.getClinicalCategory(), versionsTable);

    // initModel only produces the categories test0..test3.
    assertFalse(data.containsKey("test4"));
    assertTrue(data.containsKey("test1"));

    BiFunction<CmgForPtOtRow, Integer, Boolean> compareCondition =
        (item, checkVer) -> item.getClinicalCategory().equals("test3")
            && item.getCmg().equals("cmg9")
            && SnfComparator.betweenInclusive(item.getLowVersion(), checkVer, item.getHighVersion());

    CmgForPtOtRow expectedRow = SnfTables.get(data, "test3", compareCondition, 100);
    // Fail with a clear message if the lookup finds nothing, instead of an opaque set mismatch.
    assertTrue(expectedRow != null, "no row found for category test3 / cmg9 at version 100");

    Set<CmgForPtOtRow> expected = new HashSet<>();
    expected.add(expectedRow);

    // Category test3 holds only the fixture row built for index 9.
    Set<CmgForPtOtRow> actual = data.get("test3");
    assertEquals(expected, actual);

    // Category test0 holds the fixture rows built for indices 0, 1 and 2.
    assertEquals(3, data.get("test0").size());
  }

  /**
   * Builds ten {@link CmgForPtOtRow} fixtures through {@code CmgForPtOtRow.Builder}.
   *
   * <p>Row {@code i} (0..9) has clinical category {@code "test" + (i / 3)} and CMG
   * {@code "cmg" + i}. As a result, test0, test1 and test2 each hold three rows and test3 holds
   * one. The other columns ({@code "100"}, {@code null}, {@code "2"}, {@code "8"}) are the same for
   * every row. Presumably they include the low/high version bounds read by the selectors in
   * {@code testGet} and {@code testInitMap} (TODO confirm the column order against the builder).
   *
   * @return the fixture rows in index order
   */
  private List<CmgForPtOtRow> initModel() {
    CmgForPtOtRow.Builder builder = CmgForPtOtRow.Builder.of();
    List<CmgForPtOtRow> rows = new ArrayList<>();

    for (int i = 0; i < 10; i++) {
      String[] data = new String[] {"test" + (i / 3), "100", null, "2", "8", "cmg" + i};
      rows.add(builder.get(data));
    }

    return rows;
  }

  /**
   * Verifies that {@code SnfTables.selectAll} collects the mapped value of every row that matches
   * any of the OR-combined conditions. Row 1 matches two conditions but is expected once, and row
   * 5 matches none.
   */
  @Test
  public void testSelectAll() {
    List<Integer> expected = Arrays.asList(1, 2, 3, 4);
    expected.sort(SnfComparator.NULL_HIGH);
    final List<String> mdsItems = Arrays.asList("I1300", "I2900", "O0100H2", "K0710A2");

    // One "contains this MDS item" predicate per code; they are OR-ed together below.
    List<Predicate<NtaComorbidityRow>> conditions = new ArrayList<>();
    for (final String code : mdsItems) {
      conditions.add((item) -> item.getMdsItems().contains(code));
    }

    Map<String, Set<NtaComorbidityRow>> table = new HashMap<>();
    Set<NtaComorbidityRow> rows = SnfUtils.toSet(
        new NtaComorbidityRow("test", SnfUtils.toOrderedSet("I1300", "K0710A2", "abc"), 1),
        new NtaComorbidityRow("test", SnfUtils.toOrderedSet("I2900", "abc"), 2),
        new NtaComorbidityRow("test", SnfUtils.toOrderedSet("O0100H2", "abc"), 3),
        new NtaComorbidityRow("test", SnfUtils.toOrderedSet("K0710A2", "abc"), 4),
        new NtaComorbidityRow("test", SnfUtils.toOrderedSet("no here", "abc"), 5));

    table.put("key", rows);

    List<Integer> actual =
        SnfTables.selectAll(table, SnfUtils.or(conditions), (row) -> row.getPoint());
    // The source is a Set, so result order is not guaranteed; sort before comparing.
    actual.sort(SnfComparator.NULL_HIGH);

    assertEquals(expected, actual);
  }

}
